{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel-analytics/BigDL/blob/main/python/nano/notebooks/pytorch/tutorial/pytorch_inference_onnx.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Apply ONNXRuntime Acceleration on Inference Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Introduction\n",
    "This notebook shows how to apply ONNXRuntime acceleration to a PyTorch inference pipeline with the APIs provided by BigDL-Nano, in 4 simple steps."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 0: Prepare Environment\n",
    "Before you start using the BigDL-Nano APIs, make sure BigDL-Nano is correctly installed for PyTorch.\n",
    "\n",
    "We recommend running the commands below, especially `source bigdl-nano-init`, before the Jupyter kernel is started; otherwise some of the optimizations may not take effect.\n",
    "\n",
    "```bash\n",
    "# nightly built version\n",
    "pip install --pre --upgrade bigdl-nano[pytorch]\n",
    "# set env variables\n",
    "source bigdl-nano-init\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before using the ONNXRuntime accelerator, install the `onnx` and `onnxruntime` packages as follows:\n",
    "\n",
    "```bash\n",
    "pip install onnx onnxruntime\n",
    "```"
   ]
  },
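  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the installation steps above, the sketch below reports which required modules cannot be found. The helper name `missing_packages` is hypothetical and not part of BigDL-Nano; the module names are the ones imported in this notebook plus the two ONNX packages installed above.\n",
    "\n",
    "```python\n",
    "import importlib.util\n",
    "\n",
    "\n",
    "def missing_packages(names):\n",
    "    # Hypothetical helper: return the module names that cannot be located.\n",
    "    missing = []\n",
    "    for name in names:\n",
    "        try:\n",
    "            found = importlib.util.find_spec(name) is not None\n",
    "        except (ImportError, ValueError):\n",
    "            # Raised when a parent package (e.g. 'bigdl') is itself absent.\n",
    "            found = False\n",
    "        if not found:\n",
    "            missing.append(name)\n",
    "    return missing\n",
    "\n",
    "\n",
    "# Modules this tutorial relies on.\n",
    "print(missing_packages(['bigdl.nano', 'onnx', 'onnxruntime']))\n",
    "```\n",
    "\n",
    "An empty list suggests the environment is ready; any listed name points to a package that still needs to be installed."
   ]
  },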
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: Load Data\n",
    "Here we use the Oxford-IIIT Pet Dataset; you can find more information [here](https://www.robots.ox.ac.uk/~vgg/data/pets/). Note that the dataset is hosted on https://www.robots.ox.ac.uk, and this server is sometimes unavailable. If you get an `HTTP 500` error when downloading the dataset, the server has most likely failed temporarily; wait for a while and try again."
   ]
  },
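  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The advice to wait and try again can be automated with a small retry loop. This is only a sketch: `retry` is a hypothetical helper, not part of torchvision or BigDL-Nano, and it retries on any exception because the exact error raised by a failed download is not assumed here.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "\n",
    "def retry(action, attempts=3, delay_seconds=1.0):\n",
    "    # Hypothetical helper: call `action` until it succeeds or attempts run out.\n",
    "    for attempt in range(1, attempts + 1):\n",
    "        try:\n",
    "            return action()\n",
    "        except Exception as error:\n",
    "            if attempt == attempts:\n",
    "                raise  # give up and surface the last error\n",
    "            print(f'Attempt {attempt} failed ({error}); retrying in {delay_seconds}s')\n",
    "            time.sleep(delay_seconds)\n",
    "\n",
    "\n",
    "# Usage with the dataset from this notebook (not executed here):\n",
    "# from torchvision.datasets import OxfordIIITPet\n",
    "# dataset = retry(lambda: OxfordIIITPet(root='/tmp/data', download=True), delay_seconds=60)\n",
    "```\n",
    "\n",
    "A longer delay is usually the better choice for a server outage, since retrying immediately rarely helps."
   ]
  },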
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from torchvision.io import read_image\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import OxfordIIITPet\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "train_transform = transforms.Compose([transforms.Resize(256),\n",
    "                                      transforms.RandomCrop(224),\n",
    "                                      transforms.RandomHorizontalFlip(),\n",
    "                                      transforms.ColorJitter(brightness=.5, hue=.3),\n",
    "                                      transforms.ToTensor(),\n",
    "                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])\n",
    "val_transform = transforms.Compose([transforms.Resize([224, 224]), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])\n",
    "# Apply data augmentation to the train_dataset\n",
    "train_dataset = OxfordIIITPet(root=\"/tmp/data\", transform=train_transform, download=True)\n",
    "val_dataset = OxfordIIITPet(root=\"/tmp/data\", transform=val_transform)\n",
    "# shuffle the indices and hold out the last quarter for validation\n",
    "indices = torch.randperm(len(train_dataset))\n",
    "val_size = len(train_dataset) // 4\n",
    "train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])\n",
    "val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])\n",
    "# prepare data loaders\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=32)\n",
    "\n",
    "# Any non-empty DEV_RUN environment variable enables a quick smoke-test run\n",
    "DEV_RUN = bool(os.environ.get('DEV_RUN', False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Custom Model\n",
    "For the model, we use the pretrained `torchvision.models.resnet18`. For more details, please refer to [here](https://pytorch.org/vision/0.12/generated/torchvision.models.resnet18.html?highlight=resnet18).\n",
    "\n",
    "If you run the notebook on Colab, you may encounter an `OSError:... undefined symbol...` error related to `torchtext` when importing `Trainer` from `bigdl.nano.pytorch`. To solve this, try downgrading `torchtext` to the version compatible with your `torch`. Refer to the table [here](https://pypi.org/project/torchtext/) to find which version of `torchtext` you should install.\n",
    "\n",
    "Our nightly-built `bigdl-nano[pytorch]` is currently based on `torch==1.11.0`, so you can run the following command to solve this problem:\n",
    "\n",
    "```bash\n",
    "# Please ignore this if you do not encounter the torchtext problem\n",
    "!pip install torchtext==0.12.0\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "\n",
      "  | Name  | Type             | Params\n",
      "-------------------------------------------\n",
      "0 | model | ResNet           | 11.2 M\n",
      "1 | loss  | CrossEntropyLoss | 0     \n",
      "-------------------------------------------\n",
      "11.2 M    Trainable params\n",
      "0         Non-trainable params\n",
      "11.2 M    Total params\n",
      "44.782    Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|██████████| 87/87 [00:42<00:00,  2.07it/s, loss=0.339, v_num=18]  \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([2, 6])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torchvision.models import resnet18\n",
    "from bigdl.nano.pytorch import Trainer\n",
    "model_ft = resnet18(pretrained=True)\n",
    "num_ftrs = model_ft.fc.in_features\n",
    "\n",
    "# Here the size of each output sample is set to 37.\n",
    "model_ft.fc = torch.nn.Linear(num_ftrs, 37)\n",
    "loss_ft = torch.nn.CrossEntropyLoss()\n",
    "optimizer_ft = torch.optim.SGD(model_ft.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n",
    "\n",
    "# Compile our model with loss function, optimizer.\n",
    "model = Trainer.compile(model_ft, loss_ft, optimizer_ft)\n",
    "trainer = Trainer(max_epochs=5,\n",
    "                  fast_dev_run=DEV_RUN) # Run model quickly in test\n",
    "trainer.fit(model, train_dataloaders=train_dataloader)\n",
    "\n",
    "# Inference/Prediction\n",
    "x = torch.stack([val_dataset[0][0], val_dataset[1][0]])\n",
    "model_ft.eval()\n",
    "y_hat = model_ft(x)\n",
    "y_hat.argmax(dim=1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: ONNXRuntime Acceleration\n",
    "Trace your model as an ONNXRuntime model using `InferenceOptimizer.trace()`. The `input_sample` argument is an example input tensor with the shape the model expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigdl.nano.pytorch import InferenceOptimizer\n",
    "ort_model = InferenceOptimizer.trace(model_ft, accelerator=\"onnxruntime\", input_sample=torch.rand(1, 3, 224, 224))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The traced model can be used in almost the same way as any PyTorch module:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2, 6])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_hat = ort_model(x)\n",
    "y_hat.argmax(dim=1)"
   ]
  }
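  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check whether the traced model is actually faster on your machine, you could time both models on the same input. The sketch below is one simple way to do that; `mean_latency` is a hypothetical helper, not a BigDL-Nano API, and the measured numbers will depend on your hardware.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "\n",
    "def mean_latency(fn, sample, runs=10, warmup=2):\n",
    "    # Hypothetical helper: average wall-clock seconds per call of fn(sample).\n",
    "    for _ in range(warmup):\n",
    "        fn(sample)  # warm-up calls are excluded from the timing\n",
    "    start = time.perf_counter()\n",
    "    for _ in range(runs):\n",
    "        fn(sample)\n",
    "    return (time.perf_counter() - start) / runs\n",
    "\n",
    "\n",
    "# Usage with the objects defined in this notebook (not executed here):\n",
    "# print('PyTorch    :', mean_latency(model_ft, x))\n",
    "# print('ONNXRuntime:', mean_latency(ort_model, x))\n",
    "```\n",
    "\n",
    "Warm-up calls are kept out of the measurement because the first few calls often include one-time setup cost."
   ]
  }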
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.12 ('nano')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "75c4387adfc215da0f2d9d02c27ad9a4df553a9f0187eec0365fe565a2e50216"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
